#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import contextlib
import multiprocessing
import multiprocessing.managers
import os
import platform
import random
import socket
import signal
import subprocess
import threading
import time
import traceback

from crossrunner.test import TestEntry, domain_socket_path
from crossrunner.report import ExecReporter, SummaryReporter

RESULT_TIMEOUT = 128
RESULT_ERROR = 64


class ExecutionContext(object):
  def __init__(self, cmd, cwd, env, report):
    self._log = multiprocessing.get_logger()
    self.report = report
    self.cmd = cmd
    self.cwd = cwd
    self.env = env
    self.timer = None
    self.expired = False

  def _expire(self):
    self._log.info('Timeout')
    self.expired = True
    self.kill()

  def kill(self):
    self._log.debug('Killing process : %d' % self.proc.pid)
    if platform.system() != 'Windows':
      try:
        os.killpg(self.proc.pid, signal.SIGKILL)
      except Exception as err:
        self._log.info('Failed to kill process group : %s' % str(err))
    try:
      self.proc.kill()
    except Exception as err:
      self._log.info('Failed to kill process : %s' % str(err))
    self.report.killed()

  def _popen_args(self):
    args = {
      'cwd': self.cwd,
      'env': self.env,
      'stdout': self.report.out,
      'stderr': subprocess.STDOUT,
    }
    # make sure child processes doesn't remain after killing
    if platform.system() == 'Windows':
      DETACHED_PROCESS = 0x00000008
      args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
    else:
      args.update(preexec_fn=os.setsid)
    return args

  def start(self, timeout=0):
    self._log.debug('COMMAND: %s', ' '.join(self.cmd))
    self._log.debug('WORKDIR: %s', self.cwd)
    self._log.debug('LOGFILE: %s', self.report.logpath)
    self.report.begin()
    self.proc = subprocess.Popen(self.cmd, **self._popen_args())
    if timeout > 0:
      self.timer = threading.Timer(timeout, self._expire)
      self.timer.start()
    return self._scoped()

  @contextlib.contextmanager
  def _scoped(self):
    yield self
    self._log.debug('Killing scoped process')
    self.kill()

  def wait(self):
    self.proc.communicate()
    if self.timer:
      self.timer.cancel()
    self.report.end(self.returncode)

  @property
  def returncode(self):
    return self.proc.returncode if self.proc else None


def exec_context(port, testdir, test, prog):
  report = ExecReporter(testdir, test, prog)
  prog.build_command(port)
  return ExecutionContext(prog.command, prog.workdir, prog.env, report)


def run_test(testdir, test_dict, async=True, max_retry=3):
  try:
    logger = multiprocessing.get_logger()
    retry_count = 0
    test = TestEntry(testdir, **test_dict)
    while True:
      if stop.is_set():
        logger.debug('Skipping because shutting down')
        return None
      logger.debug('Start')
      with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
        logger.debug('Start with port %d' % port)
        sv = exec_context(port, testdir, test, test.server)
        cl = exec_context(port, testdir, test, test.client)

        logger.debug('Starting server')
        with sv.start():
          if test.delay > 0:
            logger.debug('Delaying client for %.2f seconds' % test.delay)
            time.sleep(test.delay)
          cl_retry_count = 0
          cl_max_retry = 10
          cl_retry_wait = 0.5
          while True:
            logger.debug('Starting client')
            cl.start(test.timeout)
            logger.debug('Waiting client')
            cl.wait()
            if not cl.report.maybe_false_positive() or cl_retry_count >= cl_max_retry:
              if cl_retry_count > 0 and cl_retry_count < cl_max_retry:
                logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, cl_retry_count, cl_retry_wait))
              break
            logger.debug('Server may not be ready, waiting %.2f second...' % cl_retry_wait)
            time.sleep(cl_retry_wait)
            cl_retry_count += 1

      if not sv.report.maybe_false_positive() or retry_count >= max_retry:
        logger.debug('Finish')
        return RESULT_TIMEOUT if cl.expired else cl.proc.returncode
      logger.warn('[%s]: Detected socket bind failure, retrying...' % test.server.name)
      retry_count += 1
  except (KeyboardInterrupt, SystemExit):
    logger.info('Interrupted execution')
    if not async:
      raise
    stop.set()
    return None
  except Exception as ex:
    logger.warn('Error while executing test : %s' % str(ex))
    if not async:
      raise
    logger.info(traceback.print_exc())
    return RESULT_ERROR


class PortAllocator(object):
  def __init__(self):
    self._log = multiprocessing.get_logger()
    self._lock = multiprocessing.Lock()
    self._ports = set()
    self._dom_ports = set()
    self._last_alloc = 0

  def _get_tcp_port(self):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(('127.0.0.1', 0))
    port = sock.getsockname()[1]
    self._lock.acquire()
    try:
      ok = port not in self._ports
      if ok:
        self._ports.add(port)
        self._last_alloc = time.time()
    finally:
      self._lock.release()
      sock.close()
    return port if ok else self._get_tcp_port()

  def _get_domain_port(self):
    port = random.randint(1024, 65536)
    self._lock.acquire()
    try:
      ok = port not in self._dom_ports
      if ok:
        self._dom_ports.add(port)
    finally:
      self._lock.release()
    return port if ok else self._get_domain_port()

  def alloc_port(self, socket_type):
    if socket_type in ('domain', 'abstract'):
      return self._get_domain_port()
    else:
      return self._get_tcp_port()

  # static method for inter-process invokation
  @staticmethod
  @contextlib.contextmanager
  def alloc_port_scoped(allocator, socket_type):
    port = allocator.alloc_port(socket_type)
    yield port
    allocator.free_port(socket_type, port)

  def free_port(self, socket_type, port):
    self._log.debug('free_port')
    self._lock.acquire()
    try:
      if socket_type == 'domain':
        self._dom_ports.remove(port)
        path = domain_socket_path(port)
        if os.path.exists(path):
          os.remove(path)
      elif socket_type == 'abstract':
        self._dom_ports.remove(port)
      else:
        self._ports.remove(port)
    except IOError as err:
      self._log.info('Error while freeing port : %s' % str(err))
    finally:
      self._lock.release()


class NonAsyncResult(object):
  def __init__(self, value):
    self._value = value

  def get(self, timeout=None):
    return self._value

  def wait(self, timeout=None):
    pass

  def ready(self):
    return True

  def successful(self):
    return self._value == 0


class TestDispatcher(object):
  def __init__(self, testdir, concurrency):
    self._log = multiprocessing.get_logger()
    self.testdir = testdir
    # seems needed for python 2.x to handle keyboard interrupt
    self._stop = multiprocessing.Event()
    self._async = concurrency > 1
    if not self._async:
      self._pool = None
      global stop
      global ports
      stop = self._stop
      ports = PortAllocator()
    else:
      self._m = multiprocessing.managers.BaseManager()
      self._m.register('ports', PortAllocator)
      self._m.start()
      self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,))
    self._report = SummaryReporter(testdir, concurrency > 1)
    self._log.debug(
        'TestDispatcher started with %d concurrent jobs' % concurrency)

  def _pool_init(self, address):
    global stop
    global m
    global ports
    stop = self._stop
    m = multiprocessing.managers.BaseManager(address)
    m.connect()
    ports = m.ports()

  def _dispatch_sync(self, test, cont):
    r = run_test(self.testdir, test, False)
    cont(r)
    return NonAsyncResult(r)

  def _dispatch_async(self, test, cont):
    return self._pool.apply_async(func=run_test, args=(self.testdir, test,), callback=cont)

  def dispatch(self, test):
    index = self._report.add_test(test)

    def cont(r):
      if not self._stop.is_set():
        self._log.debug('freeing port')
        self._log.debug('adding result')
        self._report.add_result(index, r, r == RESULT_TIMEOUT)
        self._log.debug('finish continuation')
    fn = self._dispatch_async if self._async else self._dispatch_sync
    return fn(test, cont)

  def wait(self):
    if self._async:
      self._pool.close()
      self._pool.join()
      self._m.shutdown()
    return self._report.end()

  def terminate(self):
    self._stop.set()
    if self._async:
      self._pool.terminate()
      self._pool.join()
      self._m.shutdown()
